from __future__ import division
import numpy as np
# import pykalman
import matplotlib.pyplot as plt
import scipy.stats as stats
import ekf_testv6 as ekf
plt.rcParams["font.family"] = "Arial"
from sklearn import datasets, linear_model
from sklearn.linear_model import LinearRegression
import statsmodels.api as sm
import matplotlib.transforms as mtransforms

def colorgen(c,step,const):
    """Build an RGB colour tuple by nudging a base colour according to index *c*.

    The index is read as three base-3 "digits" (``c``, ``c // 3``, ``c // 9``).
    Each digit is mapped to -1, 0 or +1 via ``(digit + 1) % 3 - 1`` and
    multiplied by the matching entry of *step*, giving a per-channel shift.

    Parameters
    ----------
    c : int
        Index selecting the shift combination (e.g. a run number).
    step : sequence of 3 numbers
        Shift magnitude for each channel, on a 0-255 scale.
    const : sequence of numbers
        Base colour on a 0-255 scale; it is paired with the three shifts
        using ``zip``, so the result has ``min(3, len(const))`` entries.

    Returns
    -------
    tuple of float
        Channel values clamped to 0..255 and then divided by 255.
    """
    digits = (c, c // 3, c // 9)
    shifts = [((digit + 1) % 3 - 1) * step[k] for k, digit in enumerate(digits)]

    channels = []
    for shift, base in zip(shifts, const):
        level = shift + base
        # Clamp to the valid 8-bit range before normalising.
        level = max(min(level, 255), 0)
        channels.append(level / 255)
    return tuple(channels)

# ---------------------------------------------------------------------------
# Load the simulation ensemble and derive per-column statistics.
# The array is unpacked below as (sms, smyrs) = (rows, columns).  Rows are
# later plotted one by one against `dates`, so each column is one year.
# ---------------------------------------------------------------------------
# Use a context manager so the file handle is closed once parsing finishes.
with open("all_nl_cesm2_ts_0000-0089.csv", "rb") as sim_file:
    data_simulation = np.genfromtxt(sim_file, dtype=float, delimiter=',')

# Entries that are exactly 0 are replaced by NaN so the nan-aware statistics
# below skip them (presumably 0 marks a missing value -- confirm against the
# CSV producer).
twTR = np.where(data_simulation == 0, float("nan"), data_simulation)

[sms, smyrs] = np.shape(twTR)
#dif=0.21
twTRmean = np.nanmean(twTR, axis=0)
# Per-column count of strictly positive entries; NaN compares False and is
# therefore not counted.  Stored as float, as before.
numsimsn = np.sum(twTR > 0, axis=0).astype(float)
twTRstd = np.nanstd(twTR, axis=0)
# Standard error of the column mean: std / sqrt(number of counted entries).
twTRstd2 = twTRstd / np.sqrt(numsimsn)
smdate = 1850

sdate = 1850
#data = np.genfromtxt(open("HadCRUT5.global.annual.csv", "rb"),dtype=float, delimiter=',')
dates = np.arange(sdate, sdate + smyrs)
temps = ekf.temps
# Shift the series so that its mean over indices (1960-sdate):(1990-sdate)
# equals ekf.JonesOffset + 273.15.
# NOTE(review): this assumes temps[0] corresponds to year `sdate` -- confirm.
offset = -np.mean(temps[(1960-sdate):(1990-sdate)]) + ekf.JonesOffset + 273.15 #Jones2013 13.7 to 14 (Jones1999)
print(offset)
temps = temps + offset

# Shift the ensemble so that its mean over the first 100 columns (1850-1949)
# matches the mean of `temps` over the same index window.
modeloffset = -np.mean(twTRmean[(1850-smdate):(1950-smdate)]) + np.mean(temps[(1850-sdate):(1950-sdate)])
twTRmean = twTRmean + modeloffset
twTR = twTR + modeloffset

# Deviation of every entry from its column mean (broadcast over rows);
# NaN entries stay NaN.
qqy = twTR - twTRmean #/twTRstd

#twTR=np.concatenate((twTR,[temps[70:156]]), axis=0)

# Year axis for the `temps` series: ekf.n_iters consecutive years from smdate.
dateslice = dates[(smdate-sdate):(smdate-sdate+ekf.n_iters)]
        
def plot_many_sims():
    """Draw every ensemble row and the ensemble mean on the current axes.

    Reads the module-level arrays ``dates``, ``twTR`` (one row per run) and
    ``twTRmean``.  Row 0 is drawn in ``ekf.coloruncert`` and is the only run
    given a legend label; the remaining rows get a colour from ``colorgen``
    keyed on the row index.  The mean is drawn with ``zorder=4`` so it sits
    above the individual runs.  Finishes by calling
    ``ekf.plot_boilerplate()``.
    """
    # Only the first run carries a label, so the legend shows a single entry
    # for all individual runs.
    plt.plot(dates, twTR[0, :], '-',
             label='single (unforced) model runs $(Y_n)_j$',
             color=ekf.coloruncert, linewidth=0.5, zorder=1)
    for r in range(1, sms):
        plt.plot(dates, twTR[r, :], '-',
                 color=colorgen(r, (70, 30, 30), (52, 235, 235)),
                 linewidth=0.5, zorder=1) #(52./255, 235./255, 235./255)

    # Raw string: "\o" is not a valid Python escape and triggers a
    # SyntaxWarning on recent interpreters; the resulting text is unchanged.
    plt.plot(dates, twTRmean, '-',
             label=str(sms) + r'-member ensemble average $\overline{(Y_n)_j}$',
             color=ekf.colorekf, linewidth=0.5, zorder=4)
    ekf.plot_boilerplate()

    
    
# Shared settings for the distribution-metrics figure.
axlabels = [letter + ')' for letter in 'abcd']  # panel tags: a), b), c), d)
nbins = 50                                      # number of histogram bins
cesmcolor = 'darkturquoise'                     # colour for ensemble-derived curves
def _add_panel_label(fig, ax, label, x_shift_pts, y_shift_pts=0):
    """Write a panel tag (e.g. 'a)') at the top-left corner of *ax*.

    The text is anchored at axes coordinates (0, 1) and then moved left by
    *x_shift_pts* and up by *y_shift_pts*, both in points (1/72 inch).
    """
    shift = mtransforms.ScaledTranslation(-x_shift_pts / 72, y_shift_pts / 72,
                                          fig.dpi_scale_trans)
    ax.text(0.0, 1.0, label, transform=ax.transAxes + shift,
            fontsize='large', va='bottom')


def _plot_ensemble_summary():
    """Build figure 0 (all runs, mean, standard-error band, measurements).

    Saves the result as ``summary_of_sims.png`` and ``summary_of_sims.pdf``.
    """
    plt.figure(0, figsize=(7, 6))
    plot_many_sims()
    plt.title('CESM2 LENS Model Simulations\nGlobal Annual Averages of Temperature at Surface (TS)')
    # Band of +/- 2 standard errors around the ensemble mean.
    plt.fill_between(dates[(smdate-sdate):(smdate-sdate+smyrs)],
                     twTRmean - 2*twTRstd2, twTRmean + 2*twTRstd2,
                     alpha=1, label=str(sms)+'-member ensemble standard error * 2',
                     color=ekf.colorstate, zorder=2)
    plt.plot(dateslice, temps, 'o', label='real measurements (HadCRUT5) $Y_n$',
             markersize=2, color=ekf.colorgrey, zorder=3)
    # Vertical dashed line at 2014, between the two text annotations below.
    plt.plot([2014, 2014], [280, 300], 'k--', zorder=3)
    plt.xlim(1850, 2100)
    plt.ylim(286.1, 291.5)
    plt.yticks(np.arange(286.2, 291.5, 0.4))
    plt.xticks(np.arange(1850, 2101, 25))
    plt.text(1880, 287.8, "Historical Hindcast")
    plt.text(2040, 287.8, "SSP370 Projection")
    ax = plt.gca()
    ax.set_ylabel('Temperature (K)')
    ax.set_xlabel('Year')

    # Reorder the legend by swapping the last two of the four labelled entries.
    # NOTE(review): assumes ekf.plot_boilerplate() adds no labelled artists,
    # otherwise these indices shift -- confirm.
    handles, labels = ax.get_legend_handles_labels()
    order = [0, 1, 3, 2]
    plt.legend([handles[idx] for idx in order], [labels[idx] for idx in order]) #,prop={'size': 8})

    # Secondary axis in Celsius: same limits as the Kelvin axis minus 273.15.
    axb = ax.twinx()
    axb.set_yticks(np.arange(13.2, 19, 0.4))
    axb.set_ylim(286.1 - 273.15, 291.5 - 273.15)
    axb.set_ylabel('Temperature (°C)')
    plt.savefig("summary_of_sims.png", dpi=400, format="png")
    plt.savefig("summary_of_sims.pdf", format="pdf")


def _plot_distribution_metrics():
    """Build figure 2 (histogram of deviations plus per-year moment panels).

    Left column: histogram of ``qqy`` with a zero-mean normal pdf whose scale
    is the time-mean of ``twTRstd``.  Right column, top to bottom: standard
    deviation with two linear fits, then the 3rd and 4th central moments
    divided by ``twTRstd**3`` and ``twTRstd**4``.  Regression summaries and a
    few derived numbers are printed.  Saves ``summary_of_sims_dist.pdf`` and
    ``summary_of_sims_dist.png``.
    """
    label_shift_pts = 10   # leftward shift of panel tags, in points
    mean_std = np.mean(twTRstd)

    fig2 = plt.figure(2, figsize=(6, 5))
    plt.subplots_adjust(top=0.855)
    grid = plt.GridSpec(3, 2, wspace=0.3, hspace=0.4)
    fig2.suptitle("Distribution Metrics of CESM2 LENS Ensemble")

    # --- panel a): histogram of deviations vs. a normal pdf ----------------
    ax1 = fig2.add_subplot(grid[:, 0])
    xnorm = np.linspace(-.5, .5, 100)
    ynorm = stats.norm.pdf(xnorm, scale=mean_std)
    ax1.hist(qqy.flatten(), bins=nbins, density=True, color=cesmcolor)
    ax1.plot(xnorm, ynorm, color=ekf.pcolor, linewidth=1)
    # Vertical markers under the pdf at 0, 1, 2 and 3 mean standard deviations.
    for d in range(4):
        dist = mean_std * d
        height = stats.norm.pdf(dist, scale=mean_std)
        ax1.plot([dist, dist], [0, height], color=ekf.pcolor, linewidth=0.5)
        ax1.plot([-dist, -dist], [0, height], color=ekf.pcolor, linewidth=0.5)

    ax1.set_title("Deviations from Mean \n Throughout Timeseries")
    ax1.set_xlabel("Standard Deviation (K)")
    _add_panel_label(fig2, ax1, axlabels[0], label_shift_pts + 15)

    label_list = ["", "Skewness", "Kurtosis"]
    # NOTE(review): eqslatex is not referenced anywhere in this block.  Raw
    # strings are used so "\b" in "\bar" is not read as a backspace character.
    eqslatex = [r" $ \bar{\sigma} = $",
                r"$ \bar{\mu_3} / \bar{\sigma^3} = $",
                r"$ \bar{\mu_4} / \bar{\sigma^4} $"]

    # --- panel b): standard deviation over time ----------------------------
    ax2 = fig2.add_subplot(grid[0, 1])
    ax2.plot(dates, twTRstd, color=cesmcolor)
    ax2.set_title("Std. Deviation", y=0.95)
    print("Stdev: " + str(mean_std))
    # Reference levels for the three right-hand panels: mean std, 0 and 3.
    hts = [mean_std, 0, 3]
    ax2.plot([dates[0], dates[-1]], [hts[0], hts[0]], ls="--", color=ekf.pcolor)
    ax2.set_xticks(np.arange(1850, 2150, 50))
    ax2.tick_params(axis='x', labelsize=8)

    # Fit excluding the last `endc` years, with time in centuries since 1850
    # (only the summary is printed).
    endc = 35
    X2 = sm.add_constant((dates[:-endc] - 1850) / 100)
    est2 = sm.OLS(twTRstd[:-endc], X2).fit()
    print(est2.summary2())

    # Fit over the whole period, with time in calendar years.
    X3 = sm.add_constant(dates)
    esta2 = sm.OLS(twTRstd, X3).fit()
    print(esta2.summary2())
    p = esta2.params
    # Draw the fit over every year it was estimated on, including the last.
    ax2.plot(dates, p[0] + p[1] * dates, ls=":", color=ekf.pcolor)
    print(hts[0])
    # Fitted change over 200 years relative to the fitted value at 1850.
    print(p[1] * (2050-1850)/(p[0] + p[1] * 1850))
    _add_panel_label(fig2, ax2, axlabels[1], label_shift_pts)

    # Same fit in calendar years, restricted to all but the last `endc` years.
    fit_years = dates[:-endc]
    X3 = sm.add_constant(fit_years)
    esta2 = sm.OLS(twTRstd[:-endc], X3).fit()
    p = esta2.params
    ax2.plot(fit_years, p[0] + p[1] * fit_years, ls=":", color=ekf.pcolor)
    print(p[1] * (2100-endc-1850)/(p[0] + p[1] * 1850))

    # --- panels c), d): 3rd and 4th central moments scaled by std**order ---
    for i in [1, 2]:
        ax_m = fig2.add_subplot(grid[i, 1])
        moment = stats.moment(twTR, moment=i+2, axis=0, nan_policy='omit')
        sdmom = np.divide(moment, twTRstd**(i+2))
        ax_m.plot(dates, sdmom, color=cesmcolor)
        ax_m.set_title("Std. " + label_list[i], y=0.95)
        ax_m.set_xticks(np.arange(1850, 2150, 50))
        ax_m.tick_params(axis='x', labelsize=8)
        ax_m.plot([dates[0], dates[-1]], [hts[i], hts[i]], color=ekf.pcolor)
        print(label_list[i] + ": " + str(np.mean(sdmom)))
        _add_panel_label(fig2, ax_m, axlabels[i+1], label_shift_pts)

    plt.savefig("summary_of_sims_dist.pdf", format="pdf")
    plt.savefig("summary_of_sims_dist.png", dpi=400, format="png")


if __name__ == "__main__":
    _plot_ensemble_summary()
    _plot_distribution_metrics()
    plt.show()


